{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建 Bi-GRU (双端GRU)网络实现 MNIST 分类\n",
    "\n",
    "在上个例子中，我们已经理解了在 TensorFlow 中如何来实现多层的 LSTM 和 GRU。现在我们来实现一下更加常用的 Bi-GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n",
      "WARNING:tensorflow:From <ipython-input-1-fbc070ee2f3a>:15: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../data/MNIST_data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../data/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ../data/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../data/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "(10000, 10)\n",
      "(55000, 10)\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# 用tensorflow 导入数据\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True) \n",
    "\n",
    "# 看看咱们样本的数量\n",
    "print(mnist.test.labels.shape)\n",
    "print(mnist.train.labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 一、首先设置好模型用到的各个超参数 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-3\n",
    "input_size = 28      # 每个时刻的输入特征是28维的，就是每个时刻输入一行，一行有 28 个像素\n",
    "timestep_size = 28   # 时序持续长度为28，即每做一次预测，需要先输入28行\n",
    "hidden_size = 256    # 隐含层的数量\n",
    "layer_num = 2        # LSTM layer 的层数\n",
    "class_num = 10       # 最后输出分类类别数量，如果是回归预测的话应该是 1\n",
    "cell_type = \"block_gru\"   # gru 或者 block_gru\n",
    "\n",
    "X_input = tf.placeholder(tf.float32, [None, 784])\n",
    "y_input = tf.placeholder(tf.float32, [None, class_num])\n",
    "# 在训练和测试的时候，我们想用不同的 batch_size.所以采用占位符的方式\n",
    "batch_size = tf.placeholder(tf.int32, [])  # 注意类型必须为 tf.int32, batch_size = 128\n",
    "keep_prob = tf.placeholder(tf.float32, [])"
   ]
  },
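  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `input_size` and `timestep_size` settings concrete, here is a minimal pure-Python sketch (not part of the original experiment) of how a flattened 784-pixel image can be read as 28 time steps of 28 features each. The helper name `to_sequence` and the dummy image are assumptions made for illustration only.\n",
    "\n",
    "```python\n",
    "def to_sequence(flat_image, timestep_size=28, input_size=28):\n",
    "    # Split a flat list of pixels into rows; each row is one time step.\n",
    "    assert len(flat_image) == timestep_size * input_size\n",
    "    return [flat_image[t * input_size:(t + 1) * input_size]\n",
    "            for t in range(timestep_size)]\n",
    "\n",
    "# A dummy image whose pixel values are simply their flat indices 0..783.\n",
    "flat = list(range(784))\n",
    "seq = to_sequence(flat)\n",
    "print(len(seq), len(seq[0]))  # number of time steps, features per step\n",
    "print(seq[1][:3])             # first three pixels of the second row\n",
    "```\n",
    "\n",
    "Each inner list is what the RNN would see at one time step: row 0 first, row 27 last."
   ]
  },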
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 二、开始搭建 GRU 模型，和 LSTM 模型基本一致 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"stack_bidirectional_rnn/cell_1/concat:0\", shape=(?, 28, 512), dtype=float32)\n",
      "Tensor(\"Mean:0\", shape=(?, 512), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# 把784个点的字符信息还原成 28 * 28 的图片\n",
    "# 下面几个步骤是实现 RNN / gru 的关键\n",
    "\n",
    "# **步骤1：RNN 的输入shape = (batch_size, timestep_size, input_size) \n",
    "X = tf.reshape(X_input, [-1, 28, 28])\n",
    "\n",
    "# ** 步骤2：创建 gru 结构\n",
    "def gru_cell(cell_type, num_nodes, keep_prob):\n",
    "    assert(cell_type in [\"gru\", \"block_gru\"], \"Wrong cell type.\")\n",
    "    if cell_type == \"gru\":\n",
    "        cell = tf.contrib.rnn.GRUCell(num_nodes)\n",
    "    else:\n",
    "        cell = tf.contrib.rnn.GRUBlockCellV2(num_nodes)\n",
    "    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)\n",
    "    return cell\n",
    "\n",
    "cells_fw = [gru_cell(cell_type, hidden_size, keep_prob) for _ in range(layer_num)]\n",
    "cells_bw = [gru_cell(cell_type, hidden_size, keep_prob) for _ in range(layer_num)]\n",
    "\n",
    "outputs, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(cells_fw=cells_fw,  cells_bw=cells_bw, inputs=X, dtype=tf.float32)\n",
    "\n",
    "print(outputs)  # shape=(?, 28, 512)\n",
    "# h_state = outputs[:, -1, :]\n",
    "# print(h_state)\n",
    "\n",
    "# 之前的例子中我们都取最后一个时间步作为输出，现在我们取每个时间步的平均值作为特征输出\n",
    "h_state = tf.reduce_mean(axis=1, input_tensor=outputs, keepdims=False)\n",
    "print(h_state)\n",
    "\n"
   ]
  },
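  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between keeping only the last time step and averaging over all time steps can be illustrated with plain Python lists. This is only a sketch on a made-up tensor with 1 sample, 3 time steps and 2 features; `last_step` and `mean_over_time` are hypothetical helper names, and in the model above the same reduction is performed by `tf.reduce_mean` with `axis=1`.\n",
    "\n",
    "```python\n",
    "def last_step(outputs):\n",
    "    # outputs is indexed as [batch][time][feature]; keep the final time step.\n",
    "    return [sample[-1] for sample in outputs]\n",
    "\n",
    "def mean_over_time(outputs):\n",
    "    # Average every feature across the time axis (axis=1).\n",
    "    pooled = []\n",
    "    for sample in outputs:\n",
    "        steps = len(sample)\n",
    "        pooled.append([sum(column) / steps for column in zip(*sample)])\n",
    "    return pooled\n",
    "\n",
    "outputs = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]  # 1 sample, 3 steps, 2 features\n",
    "print(last_step(outputs))       # [[5.0, 6.0]]\n",
    "print(mean_over_time(outputs))  # [[3.0, 4.0]]\n",
    "```\n",
    "\n",
    "In both cases the time axis disappears, which matches the change from shape `(?, 28, 512)` to `(?, 512)` printed above."
   ]
  },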
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 三、最后设置 loss function 和 优化器，展开训练并完成测试 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 500, train cost=0.018590, acc=0.960000; test cost=0.010676, acc=0.964100; pass 48.847047090530396s\n",
      "step 1000, train cost=0.003681, acc=0.980000; test cost=0.007433, acc=0.976600; pass 47.62926483154297s\n",
      "step 1500, train cost=0.009871, acc=0.960000; test cost=0.006339, acc=0.980500; pass 46.96517825126648s\n",
      "step 2000, train cost=0.013135, acc=0.960000; test cost=0.004675, acc=0.985000; pass 46.62891340255737s\n",
      "step 2500, train cost=0.000233, acc=1.000000; test cost=0.005693, acc=0.982500; pass 46.65591073036194s\n",
      "step 3000, train cost=0.001926, acc=0.990000; test cost=0.005315, acc=0.984700; pass 46.85902285575867s\n",
      "step 3500, train cost=0.010246, acc=0.980000; test cost=0.003928, acc=0.988100; pass 47.245213747024536s\n",
      "step 4000, train cost=0.003431, acc=0.990000; test cost=0.004989, acc=0.985800; pass 46.78474521636963s\n",
      "step 4500, train cost=0.000430, acc=1.000000; test cost=0.003681, acc=0.989100; pass 46.685664653778076s\n",
      "step 5000, train cost=0.000663, acc=1.000000; test cost=0.003303, acc=0.990400; pass 46.5777542591095s\n"
     ]
    }
   ],
   "source": [
    "import time \n",
    "\n",
    "# 开始训练和测试\n",
    "W = tf.Variable(tf.truncated_normal([hidden_size*2, class_num], stddev=0.1), dtype=tf.float32)\n",
    "bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)\n",
    "y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)\n",
    "\n",
    "\n",
    "# 损失和评估函数\n",
    "cross_entropy = -tf.reduce_mean(y_input * tf.log(y_pre))\n",
    "train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)\n",
    "\n",
    "correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y_input,1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "time0 = time.time()\n",
    "for i in range(5000):\n",
    "    _batch_size=100\n",
    "    X_batch, y_batch = mnist.train.next_batch(batch_size=_batch_size)\n",
    "    cost, acc,  _ = sess.run([cross_entropy, accuracy, train_op], feed_dict={X_input: X_batch, y_input: y_batch, keep_prob: 0.5, batch_size: _batch_size})\n",
    "    if (i+1) % 500 == 0:\n",
    "        # 分 100 个batch 迭代\n",
    "        test_acc = 0.0\n",
    "        test_cost = 0.0\n",
    "        N = 100\n",
    "        for j in range(N):\n",
    "            X_batch, y_batch = mnist.test.next_batch(batch_size=_batch_size)\n",
    "            _cost, _acc = sess.run([cross_entropy, accuracy], feed_dict={X_input: X_batch, y_input: y_batch, keep_prob: 1.0, batch_size: _batch_size})\n",
    "            test_acc += _acc\n",
    "            test_cost += _cost\n",
    "        print(\"step {}, train cost={:.6f}, acc={:.6f}; test cost={:.6f}, acc={:.6f}; pass {}s\".format(i+1, cost, acc, test_cost/N, test_acc/N, time.time() - time0))\n",
    "        time0 = time.time()"
   ]
  },
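  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One detail that helps when reading the cost values above: `-tf.reduce_mean(y_input * tf.log(y_pre))` averages over every entry of the batch-by-class matrix, so with one-hot labels it equals the conventional per-example cross-entropy divided by `class_num`. The pure-Python sketch below shows this relationship on made-up probabilities; the function names are hypothetical and the numbers are not taken from the experiment.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def weighted_log_sum(labels, probs):\n",
    "    # Sum of y * log(p) over every (example, class) entry.\n",
    "    return sum(y * math.log(p)\n",
    "               for y_row, p_row in zip(labels, probs)\n",
    "               for y, p in zip(y_row, p_row))\n",
    "\n",
    "def elementwise_mean_ce(labels, probs):\n",
    "    # Mirrors -mean(y * log(p)) taken over all batch x class entries.\n",
    "    return -weighted_log_sum(labels, probs) / (len(labels) * len(labels[0]))\n",
    "\n",
    "def per_example_ce(labels, probs):\n",
    "    # Conventional cross-entropy: sum over classes, mean over the batch.\n",
    "    return -weighted_log_sum(labels, probs) / len(labels)\n",
    "\n",
    "labels = [[1, 0], [0, 1]]         # one-hot labels for 2 classes\n",
    "probs = [[0.8, 0.2], [0.4, 0.6]]  # made-up softmax outputs\n",
    "a = elementwise_mean_ce(labels, probs)\n",
    "b = per_example_ce(labels, probs)\n",
    "print(a, b, b / a)  # the ratio equals the number of classes\n",
    "```\n",
    "\n",
    "For the 10-class setup used here, the printed costs are therefore one tenth of the conventional cross-entropy. This only rescales the loss by a constant factor."
   ]
  },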
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一般来说，这种序列分类的任务中，Bi-GRU 的结果都会比单向的要好一些，只是这个 MNIST 分类的任务比较简单，也没能够体现出明显的优势。但是和上一个实验比较，你就会发现，使用 Bi-GRU 的速度明显要比单向的慢呀，而且慢得很多。\n",
    "\n",
    "下面统计了不同模型在 MNIST 上面的速度，每迭代 500 个 training batch(batch_size=100) 和 100 个 testing batch 所花费的时间(second).因为准确率都差不多，就不写了。每个模型都是在没有其他任务的情况下跑的，所以时间还是比较有参考意义的。\n",
    "\n",
    "|Model|1-layer|2-layer|\n",
    "|:----:|:---:|:---:|\n",
    "|LSTM|6.3|13.2|\n",
    "|GRU|4.8|10.2|\n",
    "|Bi-GRU|24.0|46.7|\n",
    "|CNN|-|2.4|\n",
    "\n",
    "从上面的结果可以看到：\n",
    "- CNN 的速度相对 RNN 来说要快不少\n",
    "- GRU 要比 LSTM 稍微快一些\n",
    "- Bi-GRU 比 GRU 要慢好多好多\n",
    "\n",
    "为什么一层的 BIGRU 都能慢这么多，这个还真没搞清楚。"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
